import json
import matplotlib.pyplot as plt

# Load the per-epoch metrics recorded during training.
# NOTE(review): assumes the JSON object holds list-valued "train_loss" and
# "test_acc" entries, one value per epoch — confirm against the log writer.
# JSON text is UTF-8 by spec, so the encoding is passed explicitly instead of
# relying on the platform's locale default (e.g. GBK on Chinese Windows).
with open('training_log.json', 'r', encoding='utf-8') as f:
    history = json.load(f)

# 1-based epoch numbers for the x-axis. Each series gets its own range so a
# log whose lists differ in length (e.g. an interrupted run) still plots
# instead of making plt.plot raise ValueError on mismatched x/y lengths.
epochs = list(range(1, len(history['train_loss']) + 1))
acc_epochs = list(range(1, len(history['test_acc']) + 1))

# Left panel: training loss per epoch.
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs, history['train_loss'], marker='o', color='blue')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')

# Right panel: test accuracy per epoch.
# NOTE(review): the axis label says percent; presumably the logged values are
# already on a 0-100 scale — verify against the training script.
plt.subplot(1, 2, 2)
plt.plot(acc_epochs, history['test_acc'], marker='o', color='green')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')

# Save before showing: once plt.show() returns, the figure may already be
# closed and saving afterwards would write an empty image.
plt.tight_layout()
plt.savefig('training_curves.png')
plt.show()
